{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.72077274, -0.69317144, -0.7466632 ], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#Define the environment: Pendulum-v1 exposed through the old 4-tuple step API\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    #reset() returns only the state; step() returns (state, reward, done, info).\n",
    "    #An episode is also ended after max_steps calls to step().\n",
    "\n",
    "    def __init__(self, max_steps=200):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.max_steps = max_steps\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        #Forward keyword arguments (e.g. seed=, options=) to the wrapped env,\n",
    "        #as the gym.Wrapper API expects; the info dict is dropped\n",
    "        state, _ = self.env.reset(**kwargs)\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        #Merge terminated and truncated into a single done flag\n",
    "        done = terminated or truncated\n",
    "        self.step_n += 1\n",
    "        #Stop after max_steps steps regardless of the wrapped env's own limits\n",
    "        if self.step_n >= self.max_steps:\n",
    "            done = True\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "#Create the environment and reset it once; the initial state is the cell output\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkcElEQVR4nO3dfXRU9YH/8c/kaSAPM0mAzJCSCCot5vBQBYWpum5LlmizrQ/Yo5TSLGX1wAYOiOtWrOK223PCwa1arWC7dsXtKmyxgisrtTkBo9bIQyQ1gES7okmFSZSYmYBk8jDf3x8s83MkahJmMt+E9+ucOcfce+c737ml8869c2fiMMYYAQBgoaRETwAAgM9CpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1kpYpB555BGNHz9eI0aM0MyZM7V79+5ETQUAYKmEROq//uu/tHLlSt177716/fXXNW3aNJWUlKilpSUR0wEAWMqRiC+YnTlzpi699FL94he/kCSFw2EVFBRo2bJluvPOOwd7OgAAS6UM9gN2dnaqtrZWq1atiixLSkpScXGxampqer1PKBRSKBSK/BwOh9Xa2qpRo0bJ4XDEfc4AgNgyxqi9vV35+flKSvrsk3qDHqkPP/xQPT098ng8Ucs9Ho8OHTrU630qKir04x//eDCmBwAYRE1NTRo3btxnrh/0SA3EqlWrtHLlysjPgUBAhYWFampqksvlSuDMAAADEQwGVVBQoKysrM/dbtAjNXr0aCUnJ6u5uTlqeXNzs7xeb6/3cTqdcjqdZyx3uVxECgCGsC96y2bQr+5LS0vT9OnTVVVVFVkWDodVVVUln8832NMBAFgsIaf7Vq5cqbKyMs2YMUOXXXaZHnzwQZ04cUILFy5MxHQAAJZKSKRuuukmffDBB1q9erX8fr+++tWv6ve///0ZF1MAAM5tCfmc1NkKBoNyu90KBAK8JwUAQ1BfX8f57j4AgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1up3pF566SV961vfUn5+vhwOh7Zu3Rq13hij1atXa+zYsRo5cqSKi4v19ttvR23T2tqq+fPny+VyKTs7W4sWLdLx48fP6okAAIaffkfqxIkTmjZtmh555JFe169du1YPPfSQHn30Ue3atUsZGRkqKSlRR0dHZJv58+frwIEDqqys1LZt2/TSSy/p1ltvHfizAAAMT+YsSDJbtmyJ/BwOh43X6zX33XdfZFlbW5txOp1m48aNxhhjDh48aCSZPXv2RLbZvn27cTgc5v333+/T4wYCASPJBAKBs5k+ACBB+vo6HtP3pA4fPiy/36/i4uLIMrfbrZkzZ6qmpkaSVFNTo+zsbM2YMSOyTXFxsZKSkrRr165exw2FQgoGg1E3AMDwF9NI+f1+SZLH44la7vF4Iuv8fr/y8vKi1qekpCg3NzeyzadVVFTI7XZHbgUFBbGcNgDAUkPi6r5Vq1YpEAhEbk1NTYmeEgBgEMQ0Ul6vV5LU3Nwctby5uTmyzuv1qqWlJWp9d3e3WltbI9t8mtPplMvliroBAIa/mEZqwoQJ8nq9qqqqiiwLBoPatWuXfD6f
JMnn86mtrU21tbWRbXbs2KFwOKyZM2fGcjoAgCEupb93OH78uP785z9Hfj58+LDq6uqUm5urwsJCrVixQj/96U81ceJETZgwQffcc4/y8/N13XXXSZIuuugiXX311brlllv06KOPqqurS0uXLtXNN9+s/Pz8mD0xAMAw0N/LBnfu3GkknXErKyszxpy6DP2ee+4xHo/HOJ1OM3v2bNPQ0BA1xrFjx8y8efNMZmamcblcZuHChaa9vT3mly4CAOzU19dxhzHGJLCRAxIMBuV2uxUIBHh/CgCGoL6+jg+Jq/sAAOcmIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsFZKoicA4NxhenrU1dqqQG2tgn/6k0JHjkjGKHX0aLmmTpV7xgyleTxypKTI4XAkerqwAJECMCi6AgF99Mc/yv/00+r68MOodSfffVfBvXvlf/ppjbnmGo0uKVHa6NEJmilsQqQAxF2opUVNv/yl2vfvV/jkyc/crjsY1NGnn1aouVn58+fL6fEM4ixhIyIFIK66PvpIb99zj0JHj/btDj09an35ZYW7u3XekiVKycqK7wRhNS6cABA3J5ua9PY//3PfA3VaT4/aXnlFLdu2xWdiGDKIFIC46Dx2TO/ef79OHj484DE+rKxUZ2trDGeFoYZIAYi5zmPH9NZdd+nj//3fsxqn69gxNW/ZEqNZYSjiPSkAMXWysVHvPvBA/0/x9cYYhUOhsx8HQxaRAhAzXR99pMM/+9lZneIDPolIAYiJUEtL/67iA/qASAE4ayebmvTugw8SKMQckQJwVroCAb2zdq063nsv5mMnZ2Up5/LLYz4uhg4iBWDAuo8f11s/+pE6GhvjMn7GhRcqa+rUuIyNoYFIARiQ7vZ2NW/dqs7m5riM7xw3TuP+/u/5otlzHJ+TAtBv4c5OtTz3nFq2bYvLJeLOsWM1rqxMI/LzYz42hhaOpAD0S8/Jk2p57jkd/e1vpXA45uOnX3CBvrRwobKmTOEoCkQKQN/1dHTogxdekP+ZZ+ISqKypU/WlhQuVfv75BAqSiBSAPgp3d+tYVZWObtqk8Mcfx3z8zKIiFdx6q0YUFBAoRBApAF8o3NmpYzt3qumxx6SenpiPn/7lL+u85cvl9HoJFKJw4QSAzxXu7lbryy/ryG9+E59AXXCBCm65hUChVxxJAfhMJhxWYNcu/eXf/k09cTjFN3L8eI1btEgZEycSKPSKSAHolQmH1bZ7tw7ff79MV1fMx3ekpqpg8WJlXnQRgcJn4nQfgDMYYxSorVXT+vVxCZQkeb/zHWVOmkSg8Lk4kgIQxRij9vp6Hf7Zz+JyFZ8jNVVjb7pJnuuvlyOJ35Px+YgUgCjBffv0TkVFfP7YoMOh/Hnz5L3xxtiPjWGJSAGQdOoIKrhvn9576KG4/TXcsfPmyXP99XEZG8MTkQIQOcX3zpo1Cnd0xHx8R2qqxt5886lTfMnJMR8fwxeRAqDg66/rnfvui0ugJGnszTdr7He+E5exMbwRKeAcZoxRe12d3nv44bhcJCGHQ2PnzZOXU3wYoH5dWlNRUaFLL71UWVlZysvL03XXXaeGhoaobTo6OlReXq5Ro0YpMzNTc+fOVfOn/t5MY2OjSktLlZ6erry8PN1xxx3q7u4++2cDoF/a33hDf/7pT9XV2hrzsR2pqcr/7nc19sYb5Ujh92EMTL8iVV1drfLycr322muqrKxUV1eX5syZoxMnTkS2ue222/Tcc89p8+bNqq6u1pEjR3TDDTdE1vf09Ki0tFSdnZ169dVX9cQTT2jDhg1avXp17J4VgC8U2LtX76xdG7fP
QXmuvVaeuXMJFM6KwxhjBnrnDz74QHl5eaqurtZf/dVfKRAIaMyYMXrqqad04/9dYnro0CFddNFFqqmp0axZs7R9+3b97d/+rY4cOSKPxyNJevTRR/XDH/5QH3zwgdLS0r7wcYPBoNxutwKBgFwu10CnD5yTTp/ie/fBB9X10UexfwCHQ2O++U19acECJaenx358DAt9fR0/q0/SBQIBSVJubq4kqba2Vl1dXSouLo5sM2nSJBUWFqqmpkaSVFNToylTpkQCJUklJSUKBoM6cOBAr48TCoUUDAajbgAGpqOpSX/+l3+JS6AcKSkaU1qq/O9+V0kjR8Z8fJx7BhypcDisFStW6PLLL9fkyZMlSX6/X2lpacrOzo7a1uPxyO/3R7b5ZKBOrz+9rjcVFRVyu92RW0FBwUCnDZzTQn6/3v/Nb2Ti9B5w7lVXaezNNyslK4uvO0JMDDhS5eXl2r9/vzZt2hTL+fRq1apVCgQCkVtTU1PcHxMYTowxCrW06P3/+A8F9u6N/QMkJSn7iiuU/73vKZVT8IihAb2juXTpUm3btk0vvfSSxo0bF1nu9XrV2dmptra2qKOp5uZmeb3eyDa7d++OGu/01X+nt/k0p9Mpp9M5kKkCkNTV2qq//PrXavu/0+4xlZSknMsv15fKypT6f6f+gVjp15GUMUZLly7Vli1btGPHDk2YMCFq/fTp05WamqqqqqrIsoaGBjU2Nsrn80mSfD6f6uvr1dLSEtmmsrJSLpdLRUVFZ/NcAPSi66OPdOTJJ9X22mtxGT85PV1fWrBAzrw8TvEh5vp1JFVeXq6nnnpKzz77rLKysiLvIbndbo0cOVJut1uLFi3SypUrlZubK5fLpWXLlsnn82nWrFmSpDlz5qioqEgLFizQ2rVr5ff7dffdd6u8vJyjJSCGjDHqDgR0ZONGtb74ojTwC3k/U9KIETr/n/5JaZ96nxmIlX5dgv5ZvyU9/vjj+ru/+ztJpz7Me/vtt2vjxo0KhUIqKSnRunXrok7lvffee1qyZIlefPFFZWRkqKysTGvWrFFKHz9PwSXowBfrPn5cf9mwQcf+8Ie4jJ+am6vzli2T65JLOIJCv/X1dfysPieVKEQK+Hw9J0/qyMaNann22bgcQTnS0nTBqlVyT58e87Fxbujr6zgfBQeGEWOMwidP6uhvf6sPtm2LS6CSMzJ0/g9/qKxp02I+NvBpRAoYRkxPj/xPP63mZ56Jy/ipOTk6b+lSZU2bxik+DAoiBQwTJhxW85Ytat6yJS7jO1JSdN7y5XJfcklcxgd6Q6SAYSDc1aXmrVt1dONGmZ6emI+fnJWl8//xH5X11a/GfGzg8xApYIgzxqh5yxYd+c//jMv4KW73qav4Lr44LuMDn+esvmAWQGKdDtTReH09WXKyCpcskXvGjPiMD3wBjqSAISrc1XUqUPE6xZeZqYJbb1X2zJlyJPH7LBKDSAFDVPPWrXE7xZeckaEvff/7yvna1+RITo7LYwB9QaSAIcaEw6cukojXKb6kJI2dN0+jvv51JfXhj5AC8USkgCEk3N2t5t/9Tkc2bpTC4ZiPn5SeLu+NN2rM1VcTKFiBSAFDSMvWrTry5JNxGdvhdMrz7W8TKFiFSAFDgOnpOfUeVBz/yOiYq69W3rXXKiUjI26PAfQXkQIsd/qrjo489VRcvotPkkacd568N95IoGAdIgVY7vjBg3ENVPoFF2j8bbcp1e2Oy/jA2SBSgKVMOKwTb72lpscei1ugnF6vLrjrLqWNGROX8YGzRaQACxljdPzQITX96lc6efhwXB5j5Pnn68J77lHaqFFxGR+IBT5GDlioo6lJf/n1r3XynXfiMv7I88/XhJUrCRSsx5EUYBFjjDoaG/XeunX6+O234/IYTq+XIygMGUQKsMTpQDU++qhOvPlmXB4j/YILTr0HRaAwRHC6D7CAMUbdbW1q+tWvdPzAgbg8xojCQo1fsYKLJDCkcCQFJJgxRp0ffKD3fvEL
tdfXx+Ux0vLydOHddyvN44nL+EC8ECkggYwx6mptVeOjj6q9ri4ujxE5gvJ45HA44vIYQLxwug9IINPdrcZ16xTcuzcu46d5PCq45RalT5hAoDAkcSQFJIjp6dF7Dz+swJ49cRk/NTdX42+7TZmTJvFHCzFkESkgQU689ZZOHDp0xvL3T5zQvtZWtXd1acyIEfKNGaOM1NR+jZ3m8ahg0SIChSGPSAEJEO7qUmDvXoX8/sgyY4wOHz+ue/ft07vHj6ujp0eu1FRNzsnRv156qVL7GJuU7Gzlf+97ck2fTqAw5PEvGEiArmPH1PzMM1HL3jl+XLf88Y96MxDQyZ4eGUmBri79saVFy3ft0rGOji8e2OHQuEWLlHvllUrq59EXYCMiBSSAMUampydq2YMHDijQ1dXr9rs//FCVR4584bg5V1yhHJ+PIygMG5zuA4YDh0M5V1yhwsWL+au6GFaIFDAM5FxxhSasXClHcnKipwLEFOcEAEuUFhQo9TM+yzQ+M1NTc3N7XZdz5ZU6b9kyAoVhiUgBCZDqdiv361+PWlaSn697L75YI5KTI//HTHY4NMrp1M8uvVRF2dnRgzgcyrnyShUuXqzkESMGZd7AYON0H5AASSNHKsfnU2DvXvW0t0uSHA6HSvLzNS49Xdv+8hcd6+jQ+MxM3TRhgkY5nWeMkXPFFZpw++1cJIFhjUgBCeBwOOS6+GKNueYaNf/ud5Er/RwOhybn5GhyTs7n3j/nyit1Xnk5gcKwx79wIEGSnE55rr9eud/4hhwpfft90eF0njrFt2SJktPT4zxDIPE4kgISKDk9XeN+8AOl5uaqdedOdba09LqdIzlZIwoKTkXtr/+aL4vFOYNIAQnkcDiUkpGhsTfeKNfUqfro1Vd1/MABhfx+hUMhJWdmasS4cXLPmCH3jBkaWVhIoHBOIVKABZKcTmVOnqyML39ZPSdPynR3y4TDciQnKyk1VUnp6Urq4ylBYDjhXz1gCYfDIYfTqaReruQDzlVcOAEAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGCtfkVq/fr1mjp1qlwul1wul3w+n7Zv3x5Z39HRofLyco0aNUqZmZmaO3eumpubo8ZobGxUaWmp0tPTlZeXpzvuuEPd3d2xeTYAgGGlX5EaN26c1qxZo9raWu3du1ff+MY3dO211+rAgQOSpNtuu03PPfecNm/erOrqah05ckQ33HBD5P49PT0qLS1VZ2enXn31VT3xxBPasGGDVq9eHdtnBQAYHsxZysnJMY899phpa2szqampZvPmzZF1b775ppFkampqjDHGPP/88yYpKcn4/f7INuvXrzcul8uEQqE+P2YgEDCSTCAQONvpAwASoK+v4wN+T6qnp0ebNm3SiRMn5PP5VFtbq66uLhUXF0e2mTRpkgoLC1VTUyNJqqmp0ZQpU+TxeCLblJSUKBgMRo7GehMKhRQMBqNuAIDhr9+Rqq+vV2ZmppxOpxYvXqwtW7aoqKhIfr9faWlpys7Ojtre4/HI7/dLkvx+f1SgTq8/ve6zVFRUyO12R24FBQX9nTYAYAjqd6S+8pWvqK6uTrt27dKSJUtUVlamgwcPxmNuEatWrVIgEIjcmpqa4vp4AAA7pPT3DmlpabrwwgslSdOnT9eePXv085//XDfddJM6OzvV1tYWdTTV3Nwsr9crSfJ6vdq9e3fUeKev/ju9TW+cTqecTmd/pwoAGOLO+nNS4XBYoVBI06dPV2pqqqqqqiLrGhoa1NjYKJ/PJ0ny+Xyqr69XS0tLZJvKykq5XC4VFRWd7VQAAMNMv46kVq1apWuuuUaFhYVqb2/XU089pRdffFEvvPCC3G63Fi1apJUrVyo3N1cul0vLli2Tz+fT
rFmzJElz5sxRUVGRFixYoLVr18rv9+vuu+9WeXk5R0oAgDP0K1ItLS36/ve/r6NHj8rtdmvq1Kl64YUX9Dd/8zeSpAceeEBJSUmaO3euQqGQSkpKtG7dusj9k5OTtW3bNi1ZskQ+n08ZGRkqKyvTT37yk9g+KwDAsOAwxphET6K/gsGg3G63AoGAXC5XoqcDAOinvr6O8919AABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKx1VpFas2aNHA6HVqxYEVnW0dGh8vJyjRo1SpmZmZo7d66am5uj7tfY2KjS0lKlp6crLy9Pd9xxh7q7u89mKgCAYWjAkdqzZ49++ctfaurUqVHLb7vtNj333HPavHmzqqurdeTIEd1www2R9T09PSotLVVnZ6deffVVPfHEE9qwYYNWr1498GcBABiezAC0t7ebiRMnmsrKSnPVVVeZ5cuXG2OMaWtrM6mpqWbz5s2Rbd98800jydTU1BhjjHn++edNUlKS8fv9kW3Wr19vXC6XCYVCfXr8QCBgJJlAIDCQ6QMAEqyvr+MDOpIqLy9XaWmpiouLo5bX1taqq6sravmkSZNUWFiompoaSVJNTY2mTJkij8cT2aakpETBYFAHDhzo9fFCoZCCwWDUDQAw/KX09w6bNm3S66+/rj179pyxzu/3Ky0tTdnZ2VHLPR6P/H5/ZJtPBur0+tPrelNRUaEf//jH/Z0qAGCI69eRVFNTk5YvX64nn3xSI0aMiNeczrBq1SoFAoHIrampadAeGwCQOP2KVG1trVpaWnTJJZcoJSVFKSkpqq6u1kMPPaSUlBR5PB51dnaqra0t6n7Nzc3yer2SJK/Xe8bVfqd/Pr3NpzmdTrlcrqgbAGD461ekZs+erfr6etXV1UVuM2bM0Pz58yP/nZqaqqqqqsh9Ghoa1NjYKJ/PJ0ny+Xyqr69XS0tLZJvKykq5XC4VFRXF6GkBAIaDfr0nlZWVpcmTJ0cty8jI0KhRoyLLFy1apJUrVyo3N1cul0vLli2Tz+fTrFmzJElz5sxRUVGRFixYoLVr18rv9+vuu+9WeXm5nE5njJ4WAGA46PeFE1/kgQceUFJSkubOnatQKKSSkhKtW7cusj45OVnbtm3TkiVL5PP5lJGRobKyMv3kJz+J9VQAAEOcwxhjEj2J/goGg3K73QoEArw/BQBDUF9fx/nuPgCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtVISPYGBMMZIkoLB
YIJnAgAYiNOv36dfzz/LkIzUsWPHJEkFBQUJngkA4Gy0t7fL7XZ/5vohGanc3FxJUmNj4+c+uXNdMBhUQUGBmpqa5HK5Ej0da7Gf+ob91Dfsp74xxqi9vV35+fmfu92QjFRS0qm30txuN/8I+sDlcrGf+oD91Dfsp75hP32xvhxkcOEEAMBaRAoAYK0hGSmn06l7771XTqcz0VOxGvupb9hPfcN+6hv2U2w5zBdd/wcAQIIMySMpAMC5gUgBAKxFpAAA1iJSAABrDclIPfLIIxo/frxGjBihmTNnavfu3Yme0qB66aWX9K1vfUv5+flyOBzaunVr1HpjjFavXq2xY8dq5MiRKi4u1ttvvx21TWtrq+bPny+Xy6Xs7GwtWrRIx48fH8RnEV8VFRW69NJLlZWVpby8PF133XVqaGiI2qajo0Pl5eUaNWqUMjMzNXfuXDU3N0dt09jYqNLSUqWnpysvL0933HGHuru7B/OpxNX69es1derUyAdPfT6ftm/fHlnPPurdmjVr5HA4tGLFisgy9lWcmCFm06ZNJi0tzfz7v/+7OXDggLnllltMdna2aW5uTvTUBs3zzz9vfvSjH5lnnnnGSDJbtmyJWr9mzRrjdrvN1q1bzZ/+9Cfz7W9/20yYMMGcPHkyss3VV19tpk2bZl577TXz8ssvmwsvvNDMmzdvkJ9J/JSUlJjHH3/c7N+/39TV1ZlvfvObprCw0Bw/fjyyzeLFi01BQYGpqqoye/fuNbNmzTJf+9rXIuu7u7vN5MmTTXFxsdm3b595/vnnzejRo82qVasS8ZTi4r//+7/N//zP/5i33nrLNDQ0mLvuusukpqaa/fv3G2PYR73ZvXu3GT9+vJk6dapZvnx5ZDn7Kj6GXKQuu+wyU15eHvm5p6fH5Ofnm4qKigTOKnE+HalwOGy8Xq+57777Isva2tqM0+k0GzduNMYYc/DgQSPJ7NmzJ7LN9u3bjcPhMO+///6gzX0wtbS0GEmmurraGHNqn6SmpprNmzdHtnnzzTeNJFNTU2OMOfXLQFJSkvH7/ZFt1q9fb1wulwmFQoP7BAZRTk6Oeeyxx9hHvWhvbzcTJ040lZWV5qqrropEin0VP0PqdF9nZ6dqa2tVXFwcWZaUlKTi4mLV1NQkcGb2OHz4sPx+f9Q+crvdmjlzZmQf1dTUKDs7WzNmzIhsU1xcrKSkJO3atWvQ5zwYAoGApP//5cS1tbXq6uqK2k+TJk1SYWFh1H6aMmWKPB5PZJuSkhIFg0EdOHBgEGc/OHp6erRp0yadOHFCPp+PfdSL8vJylZaWRu0TiX9P8TSkvmD2ww8/VE9PT9T/yJLk8Xh06NChBM3KLn6/X5J63Uen1/n9fuXl5UWtT0lJUW5ubmSb4SQcDmvFihW6/PLLNXnyZEmn9kFaWpqys7Ojtv30fuptP55eN1zU19fL5/Opo6NDmZmZ2rJli4qKilRXV8c++oRNmzbp9ddf1549e85Yx7+n+BlSkQIGory8XPv379crr7yS6KlY6Stf+Yrq6uoUCAT09NNPq6ysTNXV1YmellWampq0fPlyVVZWasSIEYmezjllSJ3uGz16tJKTk8+4Yqa5uVlerzdBs7LL6f3wefvI6/WqpaUlan13d7daW1uH3X5cunSptm3bpp07d2rcuHGR5V6vV52dnWpra4va/tP7qbf9eHrdcJGWlqYLL7xQ06dPV0VFhaZNm6af//zn7KNPqK2tVUtLiy655BKlpKQoJSVF1dXVeuihh5SSkiKPx8O+ipMhFam0tDRNnz5dVVVVkWXhcFhVVVXy+XwJnJk9JkyYIK/XG7WPgsGgdu3aFdlHPp9PbW1tqq2tjWyzY8cOhcNhzZw5c9DnHA/GGC1dulRbtmzRjh07NGHChKj106dPV2pqatR+amhoUGNjY9R+qq+vjwp6ZWWlXC6XioqKBueJJEA4HFYoFGIffcLs2bNVX1+vurq6yG3GjBmaP39+5L/ZV3GS6Cs3
+mvTpk3G6XSaDRs2mIMHD5pbb73VZGdnR10xM9y1t7ebffv2mX379hlJ5v777zf79u0z7733njHm1CXo2dnZ5tlnnzVvvPGGufbaa3u9BP3iiy82u3btMq+88oqZOHHisLoEfcmSJcbtdpsXX3zRHD16NHL7+OOPI9ssXrzYFBYWmh07dpi9e/can89nfD5fZP3pS4bnzJlj6urqzO9//3szZsyYYXXJ8J133mmqq6vN4cOHzRtvvGHuvPNO43A4zB/+8AdjDPvo83zy6j5j2FfxMuQiZYwxDz/8sCksLDRpaWnmsssuM6+99lqipzSodu7caSSdcSsrKzPGnLoM/Z577jEej8c4nU4ze/Zs09DQEDXGsWPHzLx580xmZqZxuVxm4cKFpr29PQHPJj562z+SzOOPPx7Z5uTJk+Yf/uEfTE5OjklPTzfXX3+9OXr0aNQ47777rrnmmmvMyJEjzejRo83tt99uurq6BvnZxM8PfvADc95555m0tDQzZswYM3v27EigjGEffZ5PR4p9FR/8qQ4AgLWG1HtSAIBzC5ECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADW+n/0F3JV0ezNQwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#Render the current frame of the global env and display it inline\n",
    "def show():\n",
    "    frame = env.render()\n",
    "    plt.imshow(frame)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-0.08786074072122574, -1349.5105751985445)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import random\n",
    "from IPython import display\n",
    "import math\n",
    "\n",
    "\n",
    "class SAC:\n",
    "    #Soft Actor-Critic agent: 3-dim state, 1-dim action scaled to (-2, 2).\n",
    "    #Uses one policy network, two online Q networks and two target Q networks,\n",
    "    #plus a learned entropy weight alpha (stored as log(alpha)).\n",
    "\n",
    "    class ModelAction(torch.nn.Module):\n",
    "        #Policy network: state -> tanh-squashed Gaussian action and its entropy estimate\n",
    "\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "            self.fc_state = torch.nn.Sequential(\n",
    "                torch.nn.Linear(3, 128),\n",
    "                torch.nn.ReLU(),\n",
    "            )\n",
    "            self.fc_mu = torch.nn.Linear(128, 1)\n",
    "            self.fc_std = torch.nn.Sequential(\n",
    "                torch.nn.Linear(128, 1),\n",
    "                torch.nn.Softplus(),\n",
    "            )\n",
    "\n",
    "        def forward(self, state):\n",
    "            #Returns (action, entropy), both [b, 1]; action lies in (-2, 2)\n",
    "\n",
    "            #[b, 3] -> [b, 128]\n",
    "            state = self.fc_state(state)\n",
    "\n",
    "            #[b, 128] -> [b, 1]\n",
    "            mu = self.fc_mu(state)\n",
    "\n",
    "            #[b, 128] -> [b, 1], Softplus keeps std positive\n",
    "            std = self.fc_std(state)\n",
    "\n",
    "            #Define b normal distributions from mu and std\n",
    "            dist = torch.distributions.Normal(mu, std)\n",
    "\n",
    "            #Draw b samples with rsample (reparameterization): sample = mu + std * eps,\n",
    "            #eps ~ N(0, 1), so gradients flow back into mu and std\n",
    "            sample = dist.rsample()\n",
    "\n",
    "            #Squash the sample into (-1, 1) to obtain the action\n",
    "            action = torch.tanh(sample)\n",
    "\n",
    "            #Log-probability of the unsquashed sample\n",
    "            log_prob = dist.log_prob(sample)\n",
    "\n",
    "            #Change of variables for a = tanh(u): log pi(a) = log N(u) - log(1 - tanh(u)^2).\n",
    "            #tanh(u) is already `action`, so it must not be squashed a second time.\n",
    "            #1e-7 avoids log(0) when the action saturates at +-1.\n",
    "            #The constant -log(2) from the x2 scaling below is omitted.\n",
    "            log_prob = log_prob - (1 - action**2 + 1e-7).log()\n",
    "\n",
    "            #Single-sample entropy estimate: -log pi(a)\n",
    "            entropy = -log_prob\n",
    "\n",
    "            #Scale the action from (-1, 1) to (-2, 2)\n",
    "            return action * 2, entropy\n",
    "\n",
    "    class ModelValue(torch.nn.Module):\n",
    "        #Q network: (state, action) -> scalar value estimate\n",
    "\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "            self.sequential = torch.nn.Sequential(\n",
    "                torch.nn.Linear(4, 128),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.Linear(128, 128),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.Linear(128, 1),\n",
    "            )\n",
    "\n",
    "        def forward(self, state, action):\n",
    "            #[b, 3] + [b, 1] -> [b, 4]\n",
    "            state = torch.cat([state, action], dim=1)\n",
    "\n",
    "            #[b, 4] -> [b, 1]\n",
    "            return self.sequential(state)\n",
    "\n",
    "    def __init__(self):\n",
    "        self.model_action = self.ModelAction()\n",
    "\n",
    "        #Two online Q networks; the smaller estimate is used for stability\n",
    "        self.model_value1 = self.ModelValue()\n",
    "        self.model_value2 = self.ModelValue()\n",
    "\n",
    "        #Target Q networks, initialized as copies and then updated softly\n",
    "        self.model_value_next1 = self.ModelValue()\n",
    "        self.model_value_next2 = self.ModelValue()\n",
    "\n",
    "        self.model_value_next1.load_state_dict(self.model_value1.state_dict())\n",
    "        self.model_value_next2.load_state_dict(self.model_value2.state_dict())\n",
    "\n",
    "        #alpha is learnable too; it is stored as log(alpha) so exp() keeps it positive\n",
    "        self.alpha = torch.tensor(math.log(0.01))\n",
    "        self.alpha.requires_grad = True\n",
    "\n",
    "        self.optimizer_action = torch.optim.Adam(\n",
    "            self.model_action.parameters(), lr=3e-4)\n",
    "        self.optimizer_value1 = torch.optim.Adam(\n",
    "            self.model_value1.parameters(), lr=3e-3)\n",
    "        self.optimizer_value2 = torch.optim.Adam(\n",
    "            self.model_value2.parameters(), lr=3e-3)\n",
    "\n",
    "        #alpha is a trainable parameter, so it needs its own optimizer\n",
    "        self.optimizer_alpha = torch.optim.Adam([self.alpha], lr=3e-4)\n",
    "\n",
    "        self.loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    def get_action(self, state):\n",
    "        #Sample one action (a Python float) for a single 3-dim state.\n",
    "        #No autograd graph is needed just to act.\n",
    "        state = torch.FloatTensor(state).reshape(1, 3)\n",
    "        with torch.no_grad():\n",
    "            action, _ = self.model_action(state)\n",
    "        return action.item()\n",
    "\n",
    "    def test(self, play):\n",
    "        #Play one episode in the global env and return the summed reward.\n",
    "        #If play is True, roughly 20% of the frames are rendered.\n",
    "\n",
    "        #Reset the game\n",
    "        state = env.reset()\n",
    "\n",
    "        #Sum of rewards; larger is better\n",
    "        reward_sum = 0\n",
    "\n",
    "        #Play until the episode ends\n",
    "        over = False\n",
    "        while not over:\n",
    "            #Choose an action for the current state\n",
    "            action = self.get_action(state)\n",
    "\n",
    "            #Execute the action and collect the feedback\n",
    "            state, reward, over, _ = env.step([action])\n",
    "            reward_sum += reward\n",
    "\n",
    "            #Show the animation, skipping frames at random\n",
    "            if play and random.random() < 0.2:\n",
    "                display.clear_output(wait=True)\n",
    "                show()\n",
    "\n",
    "        return reward_sum\n",
    "\n",
    "    def _soft_update(self, model, model_next, tau=0.005):\n",
    "        #Polyak averaging: param_next <- (1 - tau) * param_next + tau * param\n",
    "        for param, param_next in zip(model.parameters(),\n",
    "                                     model_next.parameters()):\n",
    "            value = param_next.data * (1 - tau) + param.data * tau\n",
    "            param_next.data.copy_(value)\n",
    "\n",
    "    def _get_target(self, reward, next_state, over):\n",
    "        #TD target: reward + 0.99 * (1 - over) * (min Q_next(s', a') + alpha * entropy(a')),\n",
    "        #where a' is sampled from the current policy. The target is a fixed regression\n",
    "        #label, so it is computed without autograd.\n",
    "        with torch.no_grad():\n",
    "            #Action and entropy for next_state from the policy\n",
    "            #[b, 3] -> [b, 1],[b, 1]\n",
    "            action, entropy = self.model_action(next_state)\n",
    "\n",
    "            #Evaluate next_state with both target Q networks\n",
    "            #[b, 3],[b, 1] -> [b, 1]\n",
    "            target1 = self.model_value_next1(next_state, action)\n",
    "            target2 = self.model_value_next2(next_state, action)\n",
    "\n",
    "            #Take the smaller value for stability\n",
    "            #[b, 1]\n",
    "            target = torch.min(target1, target2)\n",
    "\n",
    "            #exp() recovers alpha from log(alpha); add the alpha-weighted entropy bonus\n",
    "            #[b, 1] + [b, 1] -> [b, 1]\n",
    "            target += self.alpha.exp() * entropy\n",
    "\n",
    "            #Discount, mask finished episodes, add the immediate reward\n",
    "            #[b, 1]\n",
    "            target *= 0.99\n",
    "            target *= (1 - over)\n",
    "            target += reward\n",
    "\n",
    "        return target\n",
    "\n",
    "    def _get_loss_action(self, state):\n",
    "        #Policy loss: mean of -(alpha * entropy + min Q(s, a)) with a ~ pi(.|s).\n",
    "        #Returns (loss, entropy); entropy ([b, 1]) is reused for the alpha update.\n",
    "\n",
    "        #[b, 3] -> [b, 1],[b, 1]\n",
    "        action, entropy = self.model_action(state)\n",
    "\n",
    "        #Evaluate the action with both online Q networks\n",
    "        #[b, 3],[b, 1] -> [b, 1]\n",
    "        value1 = self.model_value1(state, action)\n",
    "        value2 = self.model_value2(state, action)\n",
    "\n",
    "        #Take the smaller value for stability\n",
    "        #[b, 1]\n",
    "        value = torch.min(value1, value2)\n",
    "\n",
    "        #alpha * entropy should be large, so it enters the loss with a minus sign.\n",
    "        #alpha is detached: it is trained only by its own loss in train().\n",
    "        #[1] * [b, 1] -> [b, 1]\n",
    "        loss_action = -self.alpha.exp().detach() * entropy\n",
    "\n",
    "        #Subtract value: a larger value gives a smaller loss\n",
    "        loss_action -= value\n",
    "\n",
    "        return loss_action.mean(), entropy\n",
    "\n",
    "    def _get_loss_value(self, model_value, target, state, action, next_state):\n",
    "        #MSE between model_value(state, action) and the fixed target.\n",
    "        #next_state is unused; it is kept so existing calls keep working.\n",
    "        value = model_value(state, action)\n",
    "        loss_value = self.loss_fn(value, target)\n",
    "        return loss_value\n",
    "\n",
    "    def train(self, state, action, reward, next_state, over):\n",
    "        #One SAC update on a batch: state/next_state [b, 3], action/reward [b, 1],\n",
    "        #over [b, 1] done flags used as (1 - over) to mask the bootstrap term.\n",
    "\n",
    "        #Shift and scale the reward to make training easier\n",
    "        reward = (reward + 8) / 8\n",
    "\n",
    "        #TD target; the entropy term is already included\n",
    "        #[b, 1]\n",
    "        target = self._get_target(reward, next_state, over)\n",
    "        target = target.detach()\n",
    "\n",
    "        #Losses for both Q networks\n",
    "        loss_value1 = self._get_loss_value(self.model_value1, target, state,\n",
    "                                           action, next_state)\n",
    "        loss_value2 = self._get_loss_value(self.model_value2, target, state,\n",
    "                                           action, next_state)\n",
    "\n",
    "        #Update both Q networks\n",
    "        self.optimizer_value1.zero_grad()\n",
    "        loss_value1.backward()\n",
    "        self.optimizer_value1.step()\n",
    "\n",
    "        self.optimizer_value2.zero_grad()\n",
    "        loss_value2.backward()\n",
    "        self.optimizer_value2.step()\n",
    "\n",
    "        #Update the policy using the Q networks\n",
    "        loss_action, entropy = self._get_loss_action(state)\n",
    "        self.optimizer_action.zero_grad()\n",
    "        loss_action.backward()\n",
    "        self.optimizer_action.step()\n",
    "\n",
    "        #alpha loss: alpha * (entropy - target_entropy) with target_entropy = -1.\n",
    "        #Descending it raises alpha when entropy is below -1 and lowers it otherwise.\n",
    "        #[b, 1] -> [1]\n",
    "        loss_alpha = (entropy + 1).detach() * self.alpha.exp()\n",
    "        loss_alpha = loss_alpha.mean()\n",
    "\n",
    "        #Update alpha\n",
    "        self.optimizer_alpha.zero_grad()\n",
    "        loss_alpha.backward()\n",
    "        self.optimizer_alpha.step()\n",
    "\n",
    "        #Softly update the target Q networks\n",
    "        self._soft_update(self.model_value1, self.model_value_next1)\n",
    "        self._soft_update(self.model_value2, self.model_value_next2)\n",
    "\n",
    "\n",
    "teacher = SAC()\n",
    "\n",
    "#Smoke test: one update on a random batch of 5 transitions\n",
    "teacher.train(\n",
    "    torch.randn(5, 3),  #state\n",
    "    torch.randn(5, 1),  #action\n",
    "    torch.randn(5, 1),  #reward\n",
    "    torch.randn(5, 3),  #next_state\n",
    "    torch.zeros(5, 1).long(),  #over\n",
    ")\n",
    "\n",
    "#Sample one action, then play one episode without rendering and show its summed reward\n",
    "teacher.get_action([1, 2, 3]), teacher.test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3784/2919648279.py:36: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(None,\n",
       " (tensor([[ 3.2446e-01, -9.4590e-01,  6.6066e+00],\n",
       "          [ 1.5550e-01,  9.8784e-01,  4.9625e+00],\n",
       "          [-6.9978e-01, -7.1436e-01,  7.9689e+00],\n",
       "          [ 7.9677e-01, -6.0429e-01,  5.7067e+00],\n",
       "          [-3.8083e-01,  9.2464e-01,  8.0000e+00],\n",
       "          [-9.2212e-01, -3.8690e-01,  8.0000e+00],\n",
       "          [ 9.9256e-01, -1.2179e-01,  5.4723e+00],\n",
       "          [ 8.9651e-01,  4.4302e-01,  5.8756e+00],\n",
       "          [-3.6686e-02,  9.9933e-01,  8.0000e+00],\n",
       "          [-3.7479e-01, -9.2711e-01,  7.6718e+00],\n",
       "          [ 8.8221e-01,  4.7086e-01,  3.9383e-01],\n",
       "          [ 8.1295e-01,  5.8233e-01, -1.4300e+00],\n",
       "          [ 3.8096e-01,  9.2459e-01,  7.2567e+00],\n",
       "          [ 1.3075e-02,  9.9991e-01,  8.0000e+00],\n",
       "          [-9.1918e-01, -3.9384e-01,  8.0000e+00],\n",
       "          [-9.4471e-01,  3.2792e-01,  8.0000e+00],\n",
       "          [-8.9342e-01, -4.4923e-01,  8.0000e+00],\n",
       "          [-2.0781e-01, -9.7817e-01, -3.7649e+00],\n",
       "          [ 3.4587e-01, -9.3828e-01,  6.5733e+00],\n",
       "          [ 4.7267e-01,  8.8124e-01, -3.6620e+00],\n",
       "          [ 2.4390e-02,  9.9970e-01,  8.0000e+00],\n",
       "          [ 5.9446e-01, -8.0413e-01,  6.1231e+00],\n",
       "          [ 9.8037e-01,  1.9717e-01,  5.5775e+00],\n",
       "          [ 9.9408e-01, -1.0868e-01,  5.4630e+00],\n",
       "          [-4.6499e-01,  8.8531e-01, -6.3669e+00],\n",
       "          [ 8.9246e-01,  4.5112e-01,  5.9519e+00],\n",
       "          [-9.9472e-01, -1.0259e-01,  8.0000e+00],\n",
       "          [-6.9474e-01,  7.1926e-01,  8.0000e+00],\n",
       "          [-7.0819e-01,  7.0603e-01,  8.0000e+00],\n",
       "          [-9.0741e-01,  4.2025e-01, -6.7880e+00],\n",
       "          [ 3.9797e-01,  9.1740e-01,  7.3353e+00],\n",
       "          [ 8.9745e-01,  4.4112e-01,  5.9576e+00],\n",
       "          [ 9.8739e-01,  1.5830e-01,  5.6212e+00],\n",
       "          [ 5.7446e-01, -8.1853e-01,  6.1572e+00],\n",
       "          [-8.7625e-01, -4.8185e-01,  8.0000e+00],\n",
       "          [-6.4364e-01, -7.6533e-01, -5.3382e+00],\n",
       "          [-9.0544e-01, -4.2447e-01,  8.0000e+00],\n",
       "          [-7.0015e-01,  7.1400e-01,  8.0000e+00],\n",
       "          [ 9.2648e-01, -3.7634e-01,  5.4953e+00],\n",
       "          [ 6.1176e-01, -7.9104e-01,  6.1024e+00],\n",
       "          [-1.0000e+00,  2.7377e-03,  8.0000e+00],\n",
       "          [ 1.4047e-01, -9.9009e-01, -2.3131e+00],\n",
       "          [-9.9999e-01, -4.8109e-03,  8.0000e+00],\n",
       "          [ 2.7243e-02,  9.9963e-01,  8.0000e+00],\n",
       "          [ 2.4500e-02, -9.9970e-01,  7.1057e+00],\n",
       "          [ 4.1863e-01,  9.0815e-01,  7.1982e+00],\n",
       "          [ 7.0990e-01,  7.0430e-01,  6.4494e+00],\n",
       "          [ 7.9144e-01, -6.1125e-01,  5.7661e+00],\n",
       "          [-3.5980e-01,  9.3303e-01,  8.0000e+00],\n",
       "          [ 8.6255e-01,  5.0597e-01,  8.0482e-01],\n",
       "          [-9.2182e-01,  3.8761e-01,  8.0000e+00],\n",
       "          [ 8.8190e-01,  4.7144e-01, -5.4626e-01],\n",
       "          [ 7.0266e-01,  7.1153e-01, -2.4918e+00],\n",
       "          [-9.1683e-03,  9.9996e-01,  8.0000e+00],\n",
       "          [ 3.7485e-01, -9.2708e-01,  6.5285e+00],\n",
       "          [ 5.7830e-01, -8.1583e-01,  6.2212e+00],\n",
       "          [ 7.8848e-01, -6.1506e-01,  5.8339e+00],\n",
       "          [-9.2722e-01,  3.7451e-01,  8.0000e+00],\n",
       "          [ 3.1077e-01, -9.5048e-01,  6.6419e+00],\n",
       "          [ 5.7034e-01,  8.2141e-01,  3.1805e+00],\n",
       "          [ 9.9833e-01, -5.7747e-02,  5.4262e+00],\n",
       "          [ 3.0608e-01, -9.5201e-01,  6.6907e+00],\n",
       "          [-3.4875e-01, -9.3722e-01,  7.6013e+00],\n",
       "          [-3.6684e-01,  9.3029e-01,  8.0000e+00]]),\n",
       "  tensor([[ 1.5056],\n",
       "          [ 1.4822],\n",
       "          [ 1.6984],\n",
       "          [ 1.4614],\n",
       "          [ 1.7266],\n",
       "          [ 1.7273],\n",
       "          [ 1.6015],\n",
       "          [ 1.6103],\n",
       "          [ 1.8060],\n",
       "          [ 1.5844],\n",
       "          [ 0.3856],\n",
       "          [-1.0884],\n",
       "          [ 1.6643],\n",
       "          [ 1.7491],\n",
       "          [ 1.7330],\n",
       "          [ 1.7610],\n",
       "          [ 1.6097],\n",
       "          [ 0.2560],\n",
       "          [ 1.5517],\n",
       "          [-1.6227],\n",
       "          [ 1.7668],\n",
       "          [ 1.2448],\n",
       "          [ 1.3106],\n",
       "          [ 1.3664],\n",
       "          [-1.4150],\n",
       "          [ 1.4360],\n",
       "          [ 1.7247],\n",
       "          [ 1.8221],\n",
       "          [ 1.6738],\n",
       "          [-1.5287],\n",
       "          [ 1.6518],\n",
       "          [ 1.4579],\n",
       "          [ 1.4515],\n",
       "          [ 1.3943],\n",
       "          [ 1.6880],\n",
       "          [-1.4024],\n",
       "          [ 1.6054],\n",
       "          [ 1.7571],\n",
       "          [ 1.4361],\n",
       "          [ 1.3610],\n",
       "          [ 1.8087],\n",
       "          [-1.2322],\n",
       "          [ 1.7493],\n",
       "          [ 1.6710],\n",
       "          [ 1.4495],\n",
       "          [ 1.7063],\n",
       "          [ 1.7643],\n",
       "          [ 1.3042],\n",
       "          [ 1.8144],\n",
       "          [ 0.5727],\n",
       "          [ 1.7227],\n",
       "          [-1.0074],\n",
       "          [-0.0899],\n",
       "          [ 1.7075],\n",
       "          [ 1.3341],\n",
       "          [ 1.4974],\n",
       "          [ 1.3420],\n",
       "          [ 1.7387],\n",
       "          [ 1.4522],\n",
       "          [ 1.7125],\n",
       "          [ 1.1830],\n",
       "          [ 1.3861],\n",
       "          [ 1.5228],\n",
       "          [ 1.7855]]),\n",
       "  tensor([[ -5.9055],\n",
       "          [ -4.4661],\n",
       "          [-11.8564],\n",
       "          [ -3.6798],\n",
       "          [-10.2504],\n",
       "          [-13.9343],\n",
       "          [ -3.0120],\n",
       "          [ -3.6655],\n",
       "          [ -8.9873],\n",
       "          [ -9.7100],\n",
       "          [ -0.2560],\n",
       "          [ -0.5921],\n",
       "          [ -6.6611],\n",
       "          [ -8.8296],\n",
       "          [-13.8930],\n",
       "          [-14.2851],\n",
       "          [-13.5619],\n",
       "          [ -4.5863],\n",
       "          [ -5.8059],\n",
       "          [ -2.5068],\n",
       "          [ -8.7945],\n",
       "          [ -4.6235],\n",
       "          [ -3.1520],\n",
       "          [ -2.9982],\n",
       "          [ -8.2763],\n",
       "          [ -3.7636],\n",
       "          [-15.6374],\n",
       "          [-11.8736],\n",
       "          [-11.9616],\n",
       "          [-11.9426],\n",
       "          [ -6.7325],\n",
       "          [ -3.7601],\n",
       "          [ -3.1871],\n",
       "          [ -4.7124],\n",
       "          [-13.3662],\n",
       "          [ -8.0047],\n",
       "          [-13.7100],\n",
       "          [-11.9087],\n",
       "          [ -3.1707],\n",
       "          [ -4.5584],\n",
       "          [-16.2557],\n",
       "          [ -2.5811],\n",
       "          [-16.2425],\n",
       "          [ -8.7853],\n",
       "          [ -7.4422],\n",
       "          [ -6.4814],\n",
       "          [ -4.7732],\n",
       "          [ -3.7589],\n",
       "          [-10.1625],\n",
       "          [ -0.3465],\n",
       "          [-13.9300],\n",
       "          [ -0.2719],\n",
       "          [ -1.2476],\n",
       "          [ -8.8992],\n",
       "          [ -5.6718],\n",
       "          [ -4.7829],\n",
       "          [ -3.8441],\n",
       "          [-14.0080],\n",
       "          [ -5.9881],\n",
       "          [ -1.9435],\n",
       "          [ -2.9491],\n",
       "          [ -6.0654],\n",
       "          [ -9.4938],\n",
       "          [-10.1917]]),\n",
       "  tensor([[ 5.9446e-01, -8.0413e-01,  6.1231e+00],\n",
       "          [-1.3970e-01,  9.9019e-01,  5.9257e+00],\n",
       "          [-3.8083e-01, -9.2465e-01,  7.6879e+00],\n",
       "          [ 9.3042e-01, -3.6649e-01,  5.4727e+00],\n",
       "          [-7.1085e-01,  7.0335e-01,  8.0000e+00],\n",
       "          [-6.9978e-01, -7.1436e-01,  7.9689e+00],\n",
       "          [ 9.8739e-01,  1.5830e-01,  5.6212e+00],\n",
       "          [ 7.0990e-01,  7.0430e-01,  6.4494e+00],\n",
       "          [-4.2295e-01,  9.0615e-01,  8.0000e+00],\n",
       "          [-2.3461e-02, -9.9972e-01,  7.2141e+00],\n",
       "          [ 8.6255e-01,  5.0597e-01,  8.0482e-01],\n",
       "          [ 8.4525e-01,  5.3437e-01, -1.1565e+00],\n",
       "          [-9.1683e-03,  9.9996e-01,  8.0000e+00],\n",
       "          [-3.7734e-01,  9.2607e-01,  8.0000e+00],\n",
       "          [-6.9452e-01, -7.1947e-01,  7.9646e+00],\n",
       "          [-9.9783e-01, -6.5850e-02,  8.0000e+00],\n",
       "          [-6.5158e-01, -7.5858e-01,  7.9045e+00],\n",
       "          [-4.1899e-01, -9.0799e-01, -4.4601e+00],\n",
       "          [ 6.1176e-01, -7.9104e-01,  6.1024e+00],\n",
       "          [ 6.0879e-01,  7.9333e-01, -3.2445e+00],\n",
       "          [-3.6684e-01,  9.3029e-01,  8.0000e+00],\n",
       "          [ 7.9677e-01, -6.0429e-01,  5.7067e+00],\n",
       "          [ 8.8017e-01,  4.7465e-01,  5.9220e+00],\n",
       "          [ 9.8551e-01,  1.6961e-01,  5.5865e+00],\n",
       "          [-1.8677e-01,  9.8240e-01, -5.9151e+00],\n",
       "          [ 7.0149e-01,  7.1268e-01,  6.5056e+00],\n",
       "          [-8.7625e-01, -4.8185e-01,  8.0000e+00],\n",
       "          [-9.1999e-01,  3.9194e-01,  8.0000e+00],\n",
       "          [-9.2722e-01,  3.7451e-01,  8.0000e+00],\n",
       "          [-7.1873e-01,  6.9529e-01, -6.7021e+00],\n",
       "          [ 9.3008e-03,  9.9996e-01,  8.0000e+00],\n",
       "          [ 7.0936e-01,  7.0485e-01,  6.5071e+00],\n",
       "          [ 8.9745e-01,  4.4112e-01,  5.9576e+00],\n",
       "          [ 7.8306e-01, -6.2195e-01,  5.7524e+00],\n",
       "          [-6.2368e-01, -7.8168e-01,  7.8918e+00],\n",
       "          [-8.4436e-01, -5.3577e-01, -6.1226e+00],\n",
       "          [-6.7155e-01, -7.4096e-01,  7.9225e+00],\n",
       "          [-9.2292e-01,  3.8498e-01,  8.0000e+00],\n",
       "          [ 9.9346e-01, -1.1417e-01,  5.4284e+00],\n",
       "          [ 8.0988e-01, -5.8659e-01,  5.7133e+00],\n",
       "          [-9.2212e-01, -3.8690e-01,  8.0000e+00],\n",
       "          [-2.1089e-02, -9.9978e-01, -3.2405e+00],\n",
       "          [-9.1918e-01, -3.9384e-01,  8.0000e+00],\n",
       "          [-3.6418e-01,  9.3133e-01,  8.0000e+00],\n",
       "          [ 3.4587e-01, -9.3828e-01,  6.5733e+00],\n",
       "          [ 3.1936e-02,  9.9949e-01,  8.0000e+00],\n",
       "          [ 4.1437e-01,  9.1011e-01,  7.2423e+00],\n",
       "          [ 9.2774e-01, -3.7322e-01,  5.5033e+00],\n",
       "          [-6.9474e-01,  7.1926e-01,  8.0000e+00],\n",
       "          [ 8.2870e-01,  5.5970e-01,  1.2702e+00],\n",
       "          [-1.0000e+00, -1.9571e-03,  8.0000e+00],\n",
       "          [ 8.8987e-01,  4.5621e-01, -3.4379e-01],\n",
       "          [ 7.6928e-01,  6.3892e-01, -1.9716e+00],\n",
       "          [-3.9785e-01,  9.1745e-01,  8.0000e+00],\n",
       "          [ 6.3337e-01, -7.7385e-01,  6.0333e+00],\n",
       "          [ 7.8848e-01, -6.1506e-01,  5.8339e+00],\n",
       "          [ 9.2726e-01, -3.7441e-01,  5.5739e+00],\n",
       "          [-9.9987e-01, -1.6128e-02,  8.0000e+00],\n",
       "          [ 5.8376e-01, -8.1193e-01,  6.1469e+00],\n",
       "          [ 3.9333e-01,  9.1940e-01,  4.0534e+00],\n",
       "          [ 9.7585e-01,  2.1846e-01,  5.5604e+00],\n",
       "          [ 5.8128e-01, -8.1370e-01,  6.1846e+00],\n",
       "          [ 1.0781e-04, -1.0000e+00,  7.1268e+00],\n",
       "          [-7.0015e-01,  7.1400e-01,  8.0000e+00]]),\n",
       "  tensor([[0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0],\n",
       "          [0]])))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Data:\n",
    "    \"\"\"Replay buffer of (state, action, reward, next_state, over) transitions.\n",
    "\n",
    "    Filled by playing whole episodes with an agent and sampled uniformly.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_size=100000):\n",
    "        #Sample pool, oldest transitions first\n",
    "        self.datas = []\n",
    "        #Capacity: update_data trims the pool down to this many transitions\n",
    "        self.max_size = max_size\n",
    "\n",
    "    def update_data(self, agent):\n",
    "        \"\"\"Play one episode with agent, append every transition, then drop\n",
    "        the oldest transitions so that at most max_size remain.\n",
    "        \"\"\"\n",
    "        #Start a new episode\n",
    "        state = env.reset()\n",
    "\n",
    "        #Play until the episode is over\n",
    "        over = False\n",
    "        while not over:\n",
    "            #Choose an action for the current state\n",
    "            action = agent.get_action(state)\n",
    "\n",
    "            #Execute the action and observe the feedback\n",
    "            next_state, reward, over, _ = env.step([action])\n",
    "\n",
    "            #Record the transition\n",
    "            #NOTE(review): MyWrapper.step also sets over=True at the 200-step\n",
    "            #limit, so time-outs are stored as terminal transitions\n",
    "            self.datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            #Move on to the next state\n",
    "            state = next_state\n",
    "\n",
    "        #Drop the oldest transitions with one slice deletion (calling\n",
    "        #list.pop(0) in a loop would shift the whole list once per item)\n",
    "        excess = len(self.datas) - self.max_size\n",
    "        if excess > 0:\n",
    "            del self.datas[:excess]\n",
    "\n",
    "    def get_sample(self, batch_size=64):\n",
    "        \"\"\"Sample batch_size transitions uniformly without replacement.\n",
    "\n",
    "        Returns tensors state [b, 3], action [b, 1], reward [b, 1],\n",
    "        next_state [b, 3] and over [b, 1] (long). random.sample raises\n",
    "        ValueError if the pool holds fewer than batch_size transitions.\n",
    "        \"\"\"\n",
    "        samples = random.sample(self.datas, batch_size)\n",
    "\n",
    "        #[b, 3]\n",
    "        state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n",
    "        #[b, 1]\n",
    "        action = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "        #[b, 1]\n",
    "        reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "        #[b, 3]\n",
    "        next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)\n",
    "        #[b, 1]\n",
    "        over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "        return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "data = Data()\n",
    "\n",
    "data.update_data(teacher), data.get_sample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -1732.6065915213985\n",
      "10 -1329.8038996866576\n",
      "20 -867.4128248787586\n",
      "30 -721.1905882070672\n",
      "40 -898.5599240216588\n",
      "50 -194.0182722683982\n",
      "60 -284.44518420665855\n",
      "70 -1034.186885816949\n",
      "80 -418.3220165054886\n",
      "90 -515.3418884169965\n"
     ]
    }
   ],
   "source": [
    "#Online training of the teacher: alternate collecting one episode and\n",
    "#making 200 train calls on batches sampled from the buffer\n",
    "for epoch in range(100):\n",
    "    #Add one more episode of transitions to the buffer\n",
    "    data.update_data(teacher)\n",
    "\n",
    "    for i in range(200):\n",
    "        teacher.train(*data.get_sample())\n",
    "\n",
    "    #Every 10 epochs, report the mean of 10 evaluation runs\n",
    "    if epoch % 10 != 0:\n",
    "        continue\n",
    "    test_result = sum(teacher.test(play=False) for _ in range(10)) / 10\n",
    "    print(epoch, test_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.44313228130340576, -1732.9207545522383)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class CQL(SAC):\n",
    "    \"\"\"Conservative Q-Learning on top of the SAC agent defined earlier.\n",
    "\n",
    "    Only the critic loss differs from SAC: a penalty pushes down a soft\n",
    "    maximum (logsumexp) of Q over sampled actions and pushes up Q on the\n",
    "    dataset actions.\n",
    "\n",
    "    Args:\n",
    "        num_random: actions sampled per state for each of the three\n",
    "            proposals (uniform, policy at state, policy at next_state).\n",
    "        cql_weight: weight of the conservative penalty.\n",
    "        action_bound: uniform random actions are drawn from\n",
    "            [-action_bound, action_bound]; should match the environment's\n",
    "            action range (2.0 for Pendulum-v1).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_random=5, cql_weight=5.0, action_bound=2.0):\n",
    "        super().__init__()\n",
    "        self.num_random = num_random\n",
    "        self.cql_weight = cql_weight\n",
    "        self.action_bound = action_bound\n",
    "\n",
    "    def _get_loss_value(self, model_value, target, state, action, next_state):\n",
    "        \"\"\"Critic loss: SAC regression loss plus the CQL penalty.\n",
    "\n",
    "        model_value is called as model_value(state, action); state and\n",
    "        next_state are [b, 3], action is [b, 1]. Returns a scalar tensor.\n",
    "        \"\"\"\n",
    "        #Same as SAC: Q of the dataset actions regressed towards target\n",
    "        value = model_value(state, action)\n",
    "        loss_value = self.loss_fn(value, target)\n",
    "\n",
    "        #Everything below is the CQL-specific part\n",
    "        n = self.num_random\n",
    "\n",
    "        #Repeat every state n times: [b, 3] -> [b, 1, 3] -> [b, n, 3] -> [b*n, 3]\n",
    "        state = state.unsqueeze(dim=1).repeat(1, n, 1).reshape(-1, 3)\n",
    "        next_state = next_state.unsqueeze(dim=1).repeat(1, n, 1).reshape(-1, 3)\n",
    "\n",
    "        #n uniform random actions per state over the whole action range.\n",
    "        #Dataset actions reach well beyond [-1, 1] (Pendulum allows [-2, 2]),\n",
    "        #so sampling only [-1, 1] would leave part of the range unpenalised.\n",
    "        rand_action = torch.empty([len(state), 1]).uniform_(-self.action_bound, self.action_bound)\n",
    "        #log density of that uniform distribution: log(1 / (2 * action_bound))\n",
    "        rand_log_density = -math.log(2 * self.action_bound)\n",
    "\n",
    "        #Policy actions and their entropy terms: [b*n, 3] -> [b*n, 1], [b*n, 1]\n",
    "        curr_action, curr_entropy = self.model_action(state)\n",
    "        next_action, next_entropy = self.model_action(next_state)\n",
    "\n",
    "        #Q of the three proposals, all evaluated at state: [b*n, 1] -> [b, n, 1]\n",
    "        value_rand = model_value(state, rand_action).reshape(-1, n, 1)\n",
    "        value_curr = model_value(state, curr_action).reshape(-1, n, 1)\n",
    "        value_next = model_value(state, next_action).reshape(-1, n, 1)\n",
    "\n",
    "        #The entropy terms only reweight the samples; detach blocks their gradient\n",
    "        curr_entropy = curr_entropy.detach().reshape(-1, n, 1)\n",
    "        next_entropy = next_entropy.detach().reshape(-1, n, 1)\n",
    "\n",
    "        #Importance correction: subtract each proposal's log-density term.\n",
    "        #NOTE(review): CQL subtracts log pi(a|s); if model_action returns an\n",
    "        #entropy-style term (-log pi) rather than log pi, the sign of the two\n",
    "        #policy corrections is flipped. Confirm against SAC.model_action.\n",
    "        value_rand = value_rand - rand_log_density\n",
    "        value_curr = value_curr - curr_entropy\n",
    "        value_next = value_next - next_entropy\n",
    "\n",
    "        #[b, 3n, 1] -> logsumexp over the samples -> [b, 1] -> scalar\n",
    "        value_cat = torch.cat([value_rand, value_curr, value_next], dim=1)\n",
    "        loss_cat = torch.logsumexp(value_cat, dim=1).mean()\n",
    "\n",
    "        #Penalise soft-max Q over sampled actions relative to Q on dataset actions\n",
    "        loss_value = loss_value + self.cql_weight * (loss_cat - value.mean())\n",
    "\n",
    "        return loss_value\n",
    "\n",
    "\n",
    "student = CQL()\n",
    "\n",
    "student.train(\n",
    "    torch.randn(5, 3),\n",
    "    torch.randn(5, 1),\n",
    "    torch.randn(5, 1),\n",
    "    torch.randn(5, 3),\n",
    "    torch.zeros(5, 1).long(),\n",
    ")\n",
    "\n",
    "student.get_action([1, 2, 3]), student.test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -1443.422561044726\n",
      "2000 -599.6128910675469\n",
      "4000 -664.1286914371211\n",
      "6000 -776.7847175744591\n",
      "8000 -520.0376657744876\n",
      "10000 -860.4838950039799\n",
      "12000 -874.1900215606678\n",
      "14000 -817.4094007115924\n",
      "16000 -512.2919691849098\n",
      "18000 -661.9038875433434\n",
      "20000 -309.70976705013845\n",
      "22000 -458.29987190239626\n",
      "24000 -443.9337041434702\n",
      "26000 -314.6220535432918\n",
      "28000 -965.9482132045905\n",
      "30000 -426.22387888772545\n",
      "32000 -383.35949721284805\n",
      "34000 -839.9712252370849\n",
      "36000 -545.2773120223367\n",
      "38000 -453.324547373389\n",
      "40000 -509.40048541797876\n",
      "42000 -782.4016093068901\n",
      "44000 -953.1805229102653\n",
      "46000 -692.8802340078068\n",
      "48000 -1084.8934011465312\n"
     ]
    }
   ],
   "source": [
    "#Offline phase: train the CQL student only from the fixed buffer;\n",
    "#no new environment data is collected here\n",
    "for i in range(50000):\n",
    "    student.train(*data.get_sample())\n",
    "\n",
    "    #Every 2000 iterations print the mean of 10 evaluation runs\n",
    "    if i % 2000:\n",
    "        continue\n",
    "    test_result = sum(student.test(play=False) for _ in range(10)) / 10\n",
    "    print(i, test_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlNUlEQVR4nO3df3DU9YH/8ddns8mSX7shQBIoiXCCYg6hFgT27NWr5Ig2/WGlN+oxHlpGRy44IjfOyZ3i1LkZGNuerXeInbYnTq9Kh7ZYpWAvBxj0iAEjOcMP0Z5oUmATfmU3iWTzY9/fPyz7dTHWhHyy+97wfMzsjPl83vvO+/MR8iS7n911jDFGAABYyJPqBQAA8GmIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWimL1Pr16zVlyhSNGTNG8+fP1969e1O1FACApVISqV/84hdatWqVHn30Ub355puaPXu2Kisr1dbWlorlAAAs5aTiDWbnz5+va6+9Vv/+7/8uSYrFYiotLdV9992nhx56KNnLAQBYypvsb9jT06OGhgatXr06vs3j8aiiokJ1dXUD3icajSoajca/jsViOnPmjMaNGyfHcUZ8zQAAdxlj1NHRoUmTJsnj+fQH9ZIeqVOnTqm/v1/FxcUJ24uLi/X2228PeJ+1a9fqO9/5TjKWBwBIopaWFk2ePPlT9yc9Uhdj9erVWrVqVfzrcDissrIytbS0yO/3p3BlAICLEYlEVFpaqvz8/D85LumRGj9+vDIyMtTa2pqwvbW1VSUlJQPex+fzyefzfWK73+8nUgCQxj7rKZukX92XlZWlOXPmaMeOHfFtsVhMO3bsUDAYTPZyAAAWS8nDfatWrdLSpUs1d+5czZs3Tz/4wQ/U1dWlu+66KxXLAQBYKiWRuvXWW3Xy5EmtWbNGoVBIn//85/Xyyy9/4mIKAMClLSWvkxquSCSiQCCgcDjMc1IAkIYG+3Oc9+4DAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYK0hR2r37t362te+pkmTJslxHL3wwgsJ+40xWrNmjSZOnKjs7GxVVFTo3XffTRhz5swZLVmyRH6/XwUFBVq2bJk6OzuHdSAAgNFnyJHq6urS7NmztX79+gH3P/7443ryySf19NNPq76+Xrm5uaqsrFR3d3d8zJIlS3Tw4EHV1NRo69at2r17t+65556LPwoAwOhkhkGS2bJlS/zrWCxmSkpKzHe/+934tvb2duPz+czzzz9vjDHm0KFDRpLZt29ffMz27duN4zjm2LFjg/q+4XDYSDLhcHg4ywcApMhgf467+pzU0aNHFQqFVFFREd8WCAQ0f/581dXVSZLq6upUUFCguXPnxsdUVFTI4/Govr5+wHmj0agikUjCDQAw+rkaqVAoJEkqLi5O2F5cXBzfFwqFVFRUlLDf6/WqsLAwPuZCa9euVSAQiN9KS0vdXDYAwFJpcXXf6tWrFQ6H47eWlpZULwkAkASuRqqkpESS1NramrC9tbU1vq+kpERtbW0J+/v6+nTmzJn4mAv5fD75/f6EGwBg9HM1UlOnTlVJSYl27NgR3xaJRFRfX69g
MChJCgaDam9vV0NDQ3zMzp07FYvFNH/+fDeXAwBIc96h3qGzs1O///3v418fPXpUjY2NKiwsVFlZmVauXKl/+Zd/0fTp0zV16lQ98sgjmjRpkm6++WZJ0lVXXaUbb7xRd999t55++mn19vZqxYoVuu222zRp0iTXDgwAMAoM9bLBXbt2GUmfuC1dutQY89Fl6I888ogpLi42Pp/PLFy40Bw5ciRhjtOnT5vbb7/d5OXlGb/fb+666y7T0dHh+qWLAAA7DfbnuGOMMSls5EWJRCIKBAIKh8M8PwUAaWiwP8fT4uo+AMCliUgBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1hryu6AD+GzGGJm+Pn149Kja6+p07oMP1N/ZqYy8PGWXlqpgwQLlTJsmx+uV4zipXi5gLSIFuMzEYup69121vfiiwnv3KtbTI33sfZwjDQ1q++1vFZg7V8Vf/7pyZ8yQ4+FBDWAgRApwWUdTk1p+/GN1NzcPPMAYmZ4ete/Zo3MffKDSZcsUmDs3uYsE0gT/fANcYoxRx6FD+sNPf/rpgbpA9Ngxtfz0p4o0NSkNPzUHGHFECnBJ76lTOvHcczr3/vtDul/02DGd+PnP1dPaOjILA9IYkQJcYIxR5+HD6njrrYu6f+ehQ+poapKJxVxeGZDeiBTgAtPXpz9s3DisOY797GfqOXXKnQUBowSRAlzSF4kM7/7t7Xpv3TrFentdWhGQ/ogUYJFYT4/6wuFULwOwBpECLNL/4YeKcgEFEEekAIv0njqlSGNjqpcBWINIAQCsRaQAFzgej8bdcIM7kxnDC3uBPyJSgBscR2MmT3Zlqv4PP5ThCj9AEpEC3OE4ypowwZWp+sJhxaJRV+YC0h2RAlzizc93ZZ5oKKS+zk5X5gLSHZECXOA4juTS50J9+O676j1zxpW5gHRHpACXuBkqAB8hUoBLfBMnuva5UKavjyv8ABEpwDWeMWOUOXasK3P1nDzpyjxAuiNSgEs8Pp+8gYArc3UfPy7xsR0AkQJc4zhyPO78lTr50ksy/f2uzAWkMyIFuMRx8aIJno8CPkKkABflXnWVvC49LwWASAGu8hUXKyMnZ/gTGcMLegERKcBV3kBAHp9v+BMZwxV+gIgU4KqM3Fx5vN5hz2OMUXdzswsrAtIbkQJc5Nq7TvT368wrrwx/HiDNESkAgLWIFOCy8Tfe6MpvU8YYGV7Qi0sckQJcNqa01JVIxaJR9XV0uLAiIH0RKcBlvqIi9yIVibiwIiB9ESnAZR6fT26890R3S4vC+/a5MBOQvogUYCtjeP8+XPKIFOAyJyND+Z//vGvz8T5+uJQRKcBlTkaGcqdPd2WuWE8PH9mBSxqRAtzmOMqaMMGVqfrOnv0oVMAlikgBbnMc9z6h98wZxaJRV+YC0hGRAlzmOI7k0ocfRhoa1HPqlCtzAemISAE246IJXOKIFDACfCUlyrniCncmM4Yr/HDJIlLACPDm53/0zhMu6Dl92pV5gHREpIAR4MnOljc/35W5etraXJkHSEdEChgBTkaGnKwsV+Zqe/FFXiuFSxaRAkaAax9+KKmvs9OVeYB0RKSAEZJ92WXy5OSkehlAWiNSwAjJnjJFGW5EyhjedQKXLCIFjJDMwkJ53Hheyhj1nDw5/HmANESkgBGS6ffLycwc9jwmFiNSuGQRKWCEOBkZrnz4oenp0amaGhdmAtIPkQLSgOntTfUSgJQgUsAICsyf79pcvDUSLkVDitTatWt17bXXKj8/X0VFRbr55pt15MiRhDHd3d2qrq7WuHHjlJeXp8WLF6u1tTVhTHNzs6qqqpSTk6OioiI9+OCD6uvrG/7RAJbJKy93ZZ5YNKrYuXOuzAWkkyFFqra2VtXV1Xr99ddVU1Oj3t5eLVq0SF1dXfExDzzwgF566SVt3rxZtbW1On78uG655Zb4/v7+flVVVamnp0d79uzRs88+q40bN2rN
mjXuHRVgCV9JiSsv6u0/d059HR0urAhIL44ZxmMIJ0+eVFFRkWpra/WlL31J4XBYEyZM0HPPPadvfetbkqS3335bV111lerq6rRgwQJt375dX/3qV3X8+HEVFxdLkp5++mn94z/+o06ePKmsQVyyG4lEFAgEFA6H5ff7L3b5wIjrjUT01h13DPsjN3yf+5ymrlrl2sfSA6k22J/jw3pOKhwOS5IKCwslSQ0NDert7VVFRUV8zIwZM1RWVqa6ujpJUl1dna6++up4oCSpsrJSkUhEBw8eHPD7RKNRRSKRhBuQDhyX3hopeuyYui54aB24FFx0pGKxmFauXKnrrrtOM2fOlCSFQiFlZWWpoKAgYWxxcbFCoVB8zMcDdX7/+X0DWbt2rQKBQPxWWlp6scsG0hcXTuASdNGRqq6u1oEDB7Rp0yY31zOg1atXKxwOx28tLS0j/j0BN3iysjT2L//SlblMfz9X+OGSc1GRWrFihbZu3apdu3Zp8uTJ8e0lJSXq6elRe3t7wvjW1laVlJTEx1x4td/5r8+PuZDP55Pf70+4AenA8XqV82d/5spcvWfP8nopXHKGFCljjFasWKEtW7Zo586dmjp1asL+OXPmKDMzUzt27IhvO3LkiJqbmxUMBiVJwWBQTU1NavvYB7nV1NTI7/er3KXLdQFreDzK/ONztsPVe/asYkQKlxjvUAZXV1frueee029+8xvl5+fHn0MKBALKzs5WIBDQsmXLtGrVKhUWFsrv9+u+++5TMBjUggULJEmLFi1SeXm57rjjDj3++OMKhUJ6+OGHVV1dLZ/P5/4RAinkOI47bzIrqb2+XhNvu03e3FxX5gPSwZAitWHDBknSX/3VXyVsf+aZZ3TnnXdKkp544gl5PB4tXrxY0WhUlZWVeuqpp+JjMzIytHXrVi1fvlzBYFC5ublaunSpHnvsseEdCTDKxc6dk+nvT/UygKQa1uukUoXXSSGddL3zjo5+//uKnjgx7LnK169XNle3YhRIyuukAHy2zPHj5Zs40ZW5+jo7ucIPlxQiBYywjJwcefPzXZmr51NeSwiMVkQKGGEen0+eMWNcmSt6wcs3gNGOSAEjzHEcV95kVpJO8+GHuMQQKSCNxHp6Ur0EIKmIFJAE+bNmKSMnx53JYjF35gHSAJECkiDn8svlyc4e9jymv1+9Z8+6sCIgPRApIAmyCgvleIf02vkBmVhMPSdPurAiID0QKSAJPD6fHM/w/7rFurt1ds8eF1YEpAciBaSTWEy9Z86kehVA0hApIEly+Oh3YMiIFJAkgblzXZnH9PXxkR24ZBApIEl8kya5Mk9/V5f6OztdmQuwHZECkiSzoMCVefq7utRHpHCJIFJAkjgZGa7Mc665WR++954rcwG2I1JAmjG9vTLRaKqXASQFkQKSxOPzKXfGDFfmMhKfK4VLwvBfAg9gUDw+n/L+/M/V9fbbnzm23xj9+v339Wprq4qzs/XnBQX6fGGhirOzNSYjQ32RyEfv4efSQ4iArYgUkCRORoayJkwY1FiPpFsuu0w3TZ6s0Llzajp7Vr84elSBrCxdV1SkgtZWxfr6lEGkMMoRKSBJHI9HGYP88EPHcZThOMrzeDQtM1OX5+erq69PTWfP6uVjx3TspZe08rbblOvzjfCqgdTiOSkgDTiOo7zMTC2YMEFLp03T//3f/2njf/6nolxAgVGOSAFJ5A0E5PX7L/r+juOoKDtbd02frvcOHNAvf/lL9fLuExjFiBSQRL7PfU6+z31u2PPkeL265847tXfvXh06dIgr/TBqESkgibx5efLm57syV5HPpxtuuEEvv/wyD/th1CJSQBJlZGcr4zM+ofdYV5e2trTo+ffe038fP66uT3k4r7e1Vdddd50++OADneSDEDFKcXUfkCTGGHV++KG6e3o+df/Rzk49un+/3u/sVHd/v/yZmZo5dqy+d+21yrzgQxPPvvqqrvjKV3TllVdq3759Ki0tTcZhAEnFb1JA
kpw7d04//vGP9VZT04DPIb3X2am7/+d/dDgc1rn+fhlJ4d5e/U9bm+6vr9fp7u6E8ec/Rv5LX/qSXn311WQcApB0RApIkvb2dv3oRz/SgbNnZQZ4Ee4PDh5U+FMe2tt76pRqjh8fcN+0adN09OhRV9cK2IJIAUmyd+9effDBB3o9EtE5F6/Gy83NVX9/v2vzATbhOSkgSU6cOKE777xTGZK6jx9Xrotzcwk6RisiBSTJ7bffrv7+fl1xxRW64uRJnf7FLxL2V5WW6o1Tp9Q7QHCm5OVpVmHhgPN2dXXJ6+WvMkYnHu4DksTv98vv9+v06dOafPPNKvzylxP2V06apEevuUZjMjLifzEzHEfjfD59/9prVX7BJ/sWL14sSfr973+vyy+/PAlHACQf//wCksRxHF122WWqr69Xj+NobDCo8BtvqL+jI76/ctIkTc7J0dY//EGnu7s1JS9Pt06dqnEXvJGsb+JEjQ0GJUm7d+/WF7/4xaQfD5AMRApIEsdxNGPGDP3qV7/S2bNnNfGaazThppvU+qtfyfzxwgfHcTRz7FjNHDv2U+fxjh2rSXfcIe8ffyt75513tPiPv1UBow0P9wFJNGHCBM2ePVsvvPCClJmp4m9+U4U33CBnkM8pZeTlaeKtt6pg3jwZx9Frr72msrIyjR8/fmQXDqQIkQKSyOPx6NZbb9Wbb76p+vp6ZeTkaPK3v63ixYuVVVT0qfdzMjKUPWWKSu++WxNuukmerCy1tLRo165duvHGG+Xjc6UwSvFwH5Bkubm5uueee/Szn/1MEyZM0OWXX66J3/qW/LNm6eyePeo8eFDRUEixaFQZeXkaM3myAnPnKjB3rrLLyuQ4jjo7O/XUU09p3rx5Ki8vl+M4qT4sYEQQKSDJHMfRnDlz1Nraqu9///tatWqVpk6dqryZM5V7xRXqP3dOpq9PJhaTk5EhT2amPDk58ni9isVi+kNLi9avX68pU6bob/7mb5SZmZnqQwJGDJECUsDr9eqrX/2qsrOz9b3vfU/XX3+9vvzlL6ukpESZAzx0Z4xROBxWXV2dtm3bpmuuuUZ/+7d/q6ysrBSsHkgex6ThS9UjkYgCgYDC4bD8w/iUUyDVjDE6dOiQtm/frsOHD2vmzJkKBoOaNm2a/H6/zp07p+bmZu3Zs0eNjY0qKipSVVWVZs+ezfNQSGuD/TlOpIAUM8aop6dHp0+f1u7du1VfX6/33ntPnZ2dGjNmjCZPnqx58+bpi1/8okpLS5Wdnc1zUEh7RAoAYK3B/hznEnQAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYa0iR2rBhg2bNmiW/3y+/369gMKjt27fH93d3d6u6ulrjxo1TXl6eFi9erNbW1oQ5mpubVVVVpZycHBUVFenBBx9UX1+fO0cDABhVhhSpyZMna926dWpoaNAbb7yhG264Qd/4xjd08OBBSdIDDzygl156SZs3b1Ztba2OHz+uW265JX7//v5+VVVVqaenR3v27NGzzz6rjRs3as2aNe4eFQBgdDDDNHbsWPOTn/zEtLe3m8zMTLN58+b4vsOHDxtJpq6uzhhjzLZt24zH4zGhUCg+ZsOGDcbv95toNDro7xkOh40kEw6Hh7t8AEAKDPbn+EU/J9Xf369Nmzapq6tLwWBQDQ0N6u3tVUVFRXzMjBkzVFZWprq6OklSXV2drr76ahUXF8fHVFZWKhKJxH8bG0g0GlUkEkm4AQBGvyFHqqmpSXl5efL5fLr33nu1ZcsWlZeXKxQKKSsrSwUFBQnji4uLFQqFJEmhUCghUOf3n9/3adauXatAIBC/lZaWDnXZAIA0NORIXXnllWpsbFR9fb2WL1+upUuX6tChQyOxtrjVq1crHA7Hby0tLSP6/QAAdvAO9Q5Z
WVmaNm2aJGnOnDnat2+ffvjDH+rWW29VT0+P2tvbE36bam1tVUlJiSSppKREe/fuTZjv/NV/58cMxOfzyefzDXWpAIA0N+zXScViMUWjUc2ZM0eZmZnasWNHfN+RI0fU3NysYDAoSQoGg2pqalJbW1t8TE1Njfx+v8rLy4e7FADAKDOk36RWr16tm266SWVlZero6NBzzz2nV155Rb/73e8UCAS0bNkyrVq1SoWFhfL7/brvvvsUDAa1YMECSdKiRYtUXl6uO+64Q48//rhCoZAefvhhVVdX85sSAOAThhSptrY2/d3f/Z1OnDihQCCgWbNm6Xe/+53++q//WpL0xBNPyOPxaPHixYpGo6qsrNRTTz0Vv39GRoa2bt2q5cuXKxgMKjc3V0uXLtVjjz3m7lEBAEYFxxhjUr2IoYpEIgoEAgqHw/L7/aleDgBgiAb7c5z37gMAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgrWFFat26dXIcRytXroxv6+7uVnV1tcaNG6e8vDwtXrxYra2tCfdrbm5WVVWVcnJyVFRUpAcffFB9fX3DWQoAYBS66Ejt27dPP/rRjzRr1qyE7Q888IBeeuklbd68WbW1tTp+/LhuueWW+P7+/n5VVVWpp6dHe/bs0bPPPquNGzdqzZo1F38UAIDRyVyEjo4OM336dFNTU2Ouv/56c//99xtjjGlvbzeZmZlm8+bN8bGHDx82kkxdXZ0xxpht27YZj8djQqFQfMyGDRuM3+830Wh0UN8/HA4bSSYcDl/M8gEAKTbYn+MX9ZtUdXW1qqqqVFFRkbC9oaFBvb29CdtnzJihsrIy1dXVSZLq6up09dVXq7i4OD6msrJSkUhEBw8eHPD7RaNRRSKRhBsAYPTzDvUOmzZt0ptvvql9+/Z9Yl8oFFJWVpYKCgoSthcXFysUCsXHfDxQ5/ef3zeQtWvX6jvf+c5QlwoASHND+k2qpaVF999/v37+859rzJgxI7WmT1i9erXC4XD81tLSkrTvDQBInSFFqqGhQW1tbfrCF74gr9crr9er2tpaPfnkk/J6vSouLlZPT4/a29sT7tfa2qqSkhJJUklJySeu9jv/9fkxF/L5fPL7/Qk3AMDoN6RILVy4UE1NTWpsbIzf5s6dqyVLlsT/OzMzUzt27Ijf58iRI2publYwGJQkBYNBNTU1qa2tLT6mpqZGfr9f5eXlLh0WAGA0GNJzUvn5+Zo5c2bCttzcXI0bNy6+fdmyZVq1apUKCwvl9/t13333KRgMasGCBZKkRYsWqby8XHfccYcef/xxhUIhPfzww6qurpbP53PpsAAAo8GQL5z4LE888YQ8Ho8WL16saDSqyspKPfXUU/H9GRkZ2rp1q5YvX65gMKjc3FwtXbpUjz32mNtLAQCkOccYY1K9iKGKRCIKBAIKh8M8PwUAaWiwP8d57z4AgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaR
AgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLW8qV7AxTDGSJIikUiKVwIAuBjnf36f/3n+adIyUqdPn5YklZaWpnglAIDh6OjoUCAQ+NT9aRmpwsJCSVJzc/OfPLhLXSQSUWlpqVpaWuT3+1O9HGtxngaH8zQ4nKfBMcaoo6NDkyZN+pPj0jJSHs9HT6UFAgH+EAyC3+/nPA0C52lwOE+Dw3n6bIP5JYMLJwAA1iJSAABrpWWkfD6fHn30Ufl8vlQvxWqcp8HhPA0O52lwOE/ucsxnXf8HAECKpOVvUgCASwORAgBYi0gBAKxFpAAA1krLSK1fv15TpkzRmDFjNH/+fO3duzfVS0qq3bt362tf+5omTZokx3H0wgsvJOw3xmjNmjWaOHGisrOzVVFRoXfffTdhzJkzZ7RkyRL5/X4VFBRo2bJl6uzsTOJRjKy1a9fq2muvVX5+voqKinTzzTfryJEjCWO6u7tVXV2tcePGKS8vT4sXL1Zra2vCmObmZlVVVSknJ0dFRUV68MEH1dfXl8xDGVEbNmzQrFmz4i88DQaD2r59e3w/52hg69atk+M4WrlyZXwb52qEmDSzadMmk5WVZf7jP/7DHDx40Nx9992moKDAtLa2pnppSbNt2zbzz//8z+bXv/61kWS2bNmSsH/dunUmEAiYF154wfzv//6v+frXv26mTp1qzp07Fx9z4403mtmzZ5vXX3/dvPrqq2batGnm9ttvT/KRjJzKykrzzDPPmAMHDpjGxkbzla98xZSVlZnOzs74mHvvvdeUlpaaHTt2mDfeeMMsWLDA/MVf/EV8f19fn5k5c6apqKgw+/fvN9u2bTPjx483q1evTsUhjYgXX3zR/Pa3vzXvvPOOOXLkiPmnf/onk5mZaQ4cOGCM4RwNZO/evWbKlClm1qxZ5v77749v51yNjLSL1Lx580x1dXX86/7+fjNp0iSzdu3aFK4qdS6MVCwWMyUlJea73/1ufFt7e7vx+Xzm+eefN8YYc+jQISPJ7Nu3Lz5m+/btxnEcc+zYsaStPZna2tqMJFNbW2uM+eicZGZmms2bN8fHHD582EgydXV1xpiP/jHg8XhMKBSKj9mwYYPx+/0mGo0m9wCSaOzYseYnP/kJ52gAHR0dZvr06aampsZcf/318UhxrkZOWj3c19PTo4aGBlVUVMS3eTweVVRUqK6uLoUrs8fRo0cVCoUSzlEgEND8+fPj56iurk4FBQWaO3dufExFRYU8Ho/q6+uTvuZkCIfDkv7/mxM3NDSot7c34TzNmDFDZWVlCefp6quvVnFxcXxMZWWlIpGIDh48mMTVJ0d/f782bdqkrq4uBYNBztEAqqurVVVVlXBOJP48jaS0eoPZU6dOqb+/P+F/siQVFxfr7bffTtGq7BIKhSRpwHN0fl8oFFJRUVHCfq/Xq8LCwviY0SQWi2nlypW67rrrNHPmTEkfnYOsrCwVFBQkjL3wPA10Hs/vGy2ampoUDAbV3d2tvLw8bdmyReXl5WpsbOQcfcymTZv05ptvat++fZ/Yx5+nkZNWkQIuRnV1tQ4cOKDXXnst1Uux0pVXXqnGxkaFw2H98pe/1NKlS1VbW5vqZVmlpaVF999/v2pqajRmzJhUL+eSklYP940fP14ZGRmfuGKmtbVVJSUlKVqVXc6fhz91jkpKStTW1pawv6+vT2fOnBl153HFihXaunWrdu3apcmTJ8e3l5SUqKenR+3t7QnjLzxPA53H8/tGi6ysLE2bNk1z5szR2rVrNXv2bP3whz/kHH1MQ0OD2tra9IUv
fEFer1der1e1tbV68skn5fV6VVxczLkaIWkVqaysLM2ZM0c7duyIb4vFYtqxY4eCwWAKV2aPqVOnqqSkJOEcRSIR1dfXx89RMBhUe3u7Ghoa4mN27typWCym+fPnJ33NI8EYoxUrVmjLli3auXOnpk6dmrB/zpw5yszMTDhPR44cUXNzc8J5ampqSgh6TU2N/H6/ysvLk3MgKRCLxRSNRjlHH7Nw4UI1NTWpsbExfps7d66WLFkS/2/O1QhJ9ZUbQ7Vp0ybj8/nMxo0bzaFDh8w999xjCgoKEq6YGe06OjrM/v37zf79+40k86//+q9m//795oMPPjDGfHQJekFBgfnNb35j3nrrLfONb3xjwEvQr7nmGlNfX29ee+01M3369FF1Cfry5ctNIBAwr7zyijlx4kT89uGHH8bH3HvvvaasrMzs3LnTvPHGGyYYDJpgMBjff/6S4UWLFpnGxkbz8ssvmwkTJoyqS4YfeughU1tba44ePWreeust89BDDxnHccx//dd/GWM4R3/Kx6/uM4ZzNVLSLlLGGPNv//ZvpqyszGRlZZl58+aZ119/PdVLSqpdu3YZSZ+4LV261Bjz0WXojzzyiCkuLjY+n88sXLjQHDlyJGGO06dPm9tvv93k5eUZv99v7rrrLtPR0ZGCoxkZA50fSeaZZ56Jjzl37pz5+7//ezN27FiTk5NjvvnNb5oTJ04kzPP++++bm266yWRnZ5vx48ebf/iHfzC9vb1JPpqR8+1vf9tcdtllJisry0yYMMEsXLgwHihjOEd/yoWR4lyNDD6qAwBgrbR6TgoAcGkhUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFr/D9PmsLQCkQJvAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-119.65106814313626"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Visual check of the offline-trained CQL student: play=True renders the episode (see figure output)\n",
    "student.test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
